# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import gc
import shutil
from pathlib import Path
from typing import Dict

import torch
from tqdm import tqdm

"""
Sample usage:

```bash
python -m scripts.convert_checkpoint -h

python -m scripts.convert_checkpoint --output_dir checkpoints/lit-llama --model_size 7B
```
"""


def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
    """Rename the tensors of a Meta-format state dict to this project's parameter names.

    Every tensor is cast to ``dtype``. For each layer the ``wq``, ``wk`` and ``wv``
    attention weights are concatenated (in that order, along dim 0) into a single
    ``attn.c_attn.weight`` tensor; all other tensors are renamed one-to-one.

    Args:
        state_dict: mapping with ``tok_embeddings.weight``, ``output.weight``,
            ``norm.weight`` and per-layer ``layers.<idx>.*`` entries.
        dtype: dtype every returned tensor is converted to.

    Returns:
        A new dict keyed by the ``transformer.*`` / ``lm_head.weight`` names.

    Raises:
        KeyError: if an expected source key is missing from ``state_dict``.
    """

    def cast(source_key: str) -> torch.Tensor:
        return state_dict[source_key].to(dtype)

    result = {
        "transformer.wte.weight": cast("tok_embeddings.weight"),
        "lm_head.weight": cast("output.weight"),
        "transformer.ln_f.scale": cast("norm.weight"),
    }

    # (target suffix, source suffix) pairs that are plain renames within a layer
    layer_renames = (
        ("attn.c_proj.weight", "attention.wo.weight"),
        ("mlp.c_fc1.weight", "feed_forward.w1.weight"),
        ("mlp.c_proj.weight", "feed_forward.w2.weight"),
        ("mlp.c_fc2.weight", "feed_forward.w3.weight"),
        ("rms_1.scale", "attention_norm.weight"),
        ("rms_2.scale", "ffn_norm.weight"),
    )

    # layer indices are kept as strings, so they are visited in lexicographic order
    layer_ids = {key.split(".")[1] for key in state_dict if key.startswith("layers")}
    for idx in sorted(layer_ids):
        source = f"layers.{idx}"
        target = f"transformer.h.{idx}"

        # the wq, wk, wv from the FB model are stacked in our model as c_attn
        qkv = tuple(cast(f"{source}.attention.{proj}.weight") for proj in ("wq", "wk", "wv"))
        result[f"{target}.attn.c_attn.weight"] = torch.cat(qkv)

        for target_suffix, source_suffix in layer_renames:
            result[f"{target}.{target_suffix}"] = cast(f"{source}.{source_suffix}")

    return result


# Concatenation axis for every parameter that is split across the model-parallel
# checkpoint files. `meta_weights_for_nano_model` matches these keys as substrings of
# the converted parameter names and joins the shards with `torch.cat` along the given
# dimension; a parameter that matches none of the keys is not concatenated and the copy
# from the first checkpoint file is kept.
shard_dims = {
    "lm_head.weight": 0,
    "wte.weight": 1,
    "attn.c_attn.weight": 0,
    "attn.c_proj.weight": 1,
    "mlp.c_fc1.weight": 0,
    "mlp.c_fc2.weight": 0,
    "mlp.c_proj.weight": 1
}


def meta_weights_for_nano_model(
    *,
    output_dir: Path = Path("checkpoints/lit-llama"),
    checkpoint_dir: Path = Path("checkpoints/llama/"),
    model_size: str = "7B",
    dtype: str = "float32",
) -> None:
    """Convert the original checkpoint files of one model size into a single state dict file.

    All ``*.pth`` files in ``checkpoint_dir / model_size`` are loaded in sorted order,
    converted with ``convert_state_dict`` and merged: parameters listed in ``shard_dims``
    are concatenated along their shard dimension, all others are taken from the first
    file. The result is saved as ``output_dir / model_size / "lit-llama.pth"`` and
    ``tokenizer.model`` is copied from ``checkpoint_dir`` to ``output_dir``.

    Args:
        output_dir: directory that receives the tokenizer and the ``model_size`` sub-directory.
        checkpoint_dir: directory holding ``tokenizer.model`` and the ``model_size`` sub-directory.
        model_size: name of the sub-directory to convert, e.g. ``"7B"``.
        dtype: name of a ``torch`` dtype attribute (e.g. ``"float32"``) the weights are cast to.

    Raises:
        ValueError: if ``dtype`` does not name a ``torch.dtype``. Raised before anything
            is written to disk.
        RuntimeError: if no ``*.pth`` file is found in ``checkpoint_dir / model_size``.
    """
    # Validate the argument first so that a typo fails fast, without creating directories
    # or copying files as a side effect.
    torch_dtype = getattr(torch, dtype, None)
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")

    # Path(...) is a no-op for Path arguments and additionally accepts plain strings
    output_dir = Path(output_dir) / model_size
    checkpoint_dir = Path(checkpoint_dir) / model_size
    output_dir.mkdir(parents=True, exist_ok=True)

    # the tokenizer is the same for all model sizes, so we store it in the parent dir
    shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)

    checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
    n_checkpoints = len(checkpoint_files)

    if not checkpoint_files:
        raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")

    # for the bigger models, there are multiple model-parallel checkpoints
    # and we combine them into one single file
    combined = None
    for file in tqdm(checkpoint_files, total=n_checkpoints):
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
        # Only run this script on checkpoint files obtained from a trusted source.
        checkpoint = torch.load(file, map_location="cpu")
        converted = convert_state_dict(checkpoint, dtype=torch_dtype)
        del checkpoint

        if combined is None:
            # the first file provides every non-sharded parameter and the first shard
            combined = converted
            continue

        for name, param in converted.items():
            # first shard_dims key contained in the parameter name, if any
            dim = next((d for k, d in shard_dims.items() if k in name), None)
            if dim is None:
                # not sharded: keep the copy taken from the first checkpoint
                continue
            combined[name] = torch.cat((combined[name], param), dim=dim)

        del converted
        gc.collect()

    # With a single checkpoint the reordering below is the identity permutation, so the
    # clone of every c_attn tensor can be skipped entirely.
    if n_checkpoints > 1:
        for name, param in combined.items():
            if "c_attn" not in name:
                continue

            # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]
            src_chunk_len = param.shape[0] // n_checkpoints  # rows contributed by one file
            mat_len = src_chunk_len // 3  # rows of a single Q, K or V slice
            dst_chunk_len = mat_len * n_checkpoints  # rows of the merged Q, K or V
            # read from a copy because `param` is overwritten in place
            attn = torch.clone(param)
            for i in range(n_checkpoints):
                for j in range(3):
                    dst_start = j * dst_chunk_len + i * mat_len
                    src_start = i * src_chunk_len + j * mat_len
                    param[dst_start: dst_start + mat_len] = attn[src_start: src_start + mat_len]

            del attn
            gc.collect()

    torch.save(combined, output_dir / "lit-llama.pth")


# Script entry point: only runs when the file is executed, not when it is imported.
if __name__ == "__main__":
    # imported here so that importing this module does not require jsonargparse
    from jsonargparse import CLI

    # hands the conversion function to jsonargparse's CLI helper; presumably its keyword
    # arguments become the command-line options — confirm against the jsonargparse docs
    CLI(meta_weights_for_nano_model)
